{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2d5e5b9f",
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "source": [
    "# Multi-Head Attention\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3cf184a4",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:42:05.720112Z",
     "iopub.status.busy": "2023-08-18T19:42:05.719157Z",
     "iopub.status.idle": "2023-08-18T19:42:08.696163Z",
     "shell.execute_reply": "2023-08-18T19:42:08.694248Z"
    },
    "origin_pos": 3,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "from torch import nn\n",
    "from d2l import torch as d2l"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8ff77b7",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "Choose the scaled dot product attention\n",
    "for each head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "bdc6ec21",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:42:08.702133Z",
     "iopub.status.busy": "2023-08-18T19:42:08.701108Z",
     "iopub.status.idle": "2023-08-18T19:42:08.710324Z",
     "shell.execute_reply": "2023-08-18T19:42:08.709226Z"
    },
    "origin_pos": 8,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "class MultiHeadAttention(d2l.Module):  \n",
    "    \"\"\"Multi-head attention.\"\"\"\n",
    "    def __init__(self, num_hiddens, num_heads, dropout, bias=False, **kwargs):\n",
    "        super().__init__()\n",
    "        self.num_heads = num_heads\n",
    "        self.attention = d2l.DotProductAttention(dropout)\n",
    "        self.W_q = nn.LazyLinear(num_hiddens, bias=bias)\n",
    "        self.W_k = nn.LazyLinear(num_hiddens, bias=bias)\n",
    "        self.W_v = nn.LazyLinear(num_hiddens, bias=bias)\n",
    "        self.W_o = nn.LazyLinear(num_hiddens, bias=bias)\n",
    "\n",
    "    def forward(self, queries, keys, values, valid_lens):\n",
    "        queries = self.transpose_qkv(self.W_q(queries))\n",
    "        keys = self.transpose_qkv(self.W_k(keys))\n",
    "        values = self.transpose_qkv(self.W_v(values))\n",
    "\n",
    "        if valid_lens is not None:\n",
    "            valid_lens = torch.repeat_interleave(\n",
    "                valid_lens, repeats=self.num_heads, dim=0)\n",
    "\n",
    "        output = self.attention(queries, keys, values, valid_lens)\n",
    "        output_concat = self.transpose_output(output)\n",
    "        return self.W_o(output_concat)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d50441e5",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "Parallel computation of multiple heads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4da9ac2c",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:42:08.714864Z",
     "iopub.status.busy": "2023-08-18T19:42:08.714207Z",
     "iopub.status.idle": "2023-08-18T19:42:08.723031Z",
     "shell.execute_reply": "2023-08-18T19:42:08.722024Z"
    },
    "origin_pos": 13,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "@d2l.add_to_class(MultiHeadAttention)  \n",
    "def transpose_qkv(self, X):\n",
    "    \"\"\"Transposition for parallel computation of multiple attention heads.\"\"\"\n",
    "    X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)\n",
    "    X = X.permute(0, 2, 1, 3)\n",
    "    return X.reshape(-1, X.shape[2], X.shape[3])\n",
    "\n",
    "@d2l.add_to_class(MultiHeadAttention)  \n",
    "def transpose_output(self, X):\n",
    "    \"\"\"Reverse the operation of transpose_qkv.\"\"\"\n",
    "    X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])\n",
    "    X = X.permute(0, 2, 1, 3)\n",
    "    return X.reshape(X.shape[0], X.shape[1], -1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d6644d7",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "Test our implemented"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "18451bc5",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:42:08.726610Z",
     "iopub.status.busy": "2023-08-18T19:42:08.726042Z",
     "iopub.status.idle": "2023-08-18T19:42:08.759713Z",
     "shell.execute_reply": "2023-08-18T19:42:08.758787Z"
    },
    "origin_pos": 17,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "num_hiddens, num_heads = 100, 5\n",
    "attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)\n",
    "batch_size, num_queries, num_kvpairs = 2, 4, 6\n",
    "valid_lens = torch.tensor([3, 2])\n",
    "X = torch.ones((batch_size, num_queries, num_hiddens))\n",
    "Y = torch.ones((batch_size, num_kvpairs, num_hiddens))\n",
    "d2l.check_shape(attention(X, Y, Y, valid_lens),\n",
    "                (batch_size, num_queries, num_hiddens))"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Slideshow",
  "language_info": {
   "name": "python"
  },
  "required_libs": [],
  "rise": {
   "autolaunch": true,
   "enable_chalkboard": true,
   "overlay": "<div class='my-top-right'><img height=80px src='http://d2l.ai/_static/logo-with-text.png'/></div><div class='my-top-left'></div>",
   "scroll": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}